{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# KNN-K近邻法\n",
    "是一种基本的分类与回归算法，K近邻法输入为实例的特征向量，对应于空间的点，输出为实例的类别"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn import datasets\n",
    "from utils.metric import euclidean_distance,accuracy_score\n",
    "from utils.progress import normalize,train_test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class KNN(object):\n",
    "    def __init__(self,k=5):\n",
    "        self.k = 5\n",
    "    def _vote(self,neighbours):\n",
    "        counts = np.bincount(neighbours[:, 1].astype('int'))\n",
    "        return counts.argmax()\n",
    "    def predict(self, X_test, X_train, y_train):\n",
    "        y_pred = np.empty(X_test.shape[0])\n",
    "        # 对每一个test进行循环\n",
    "        for i,test in enumerate(X_test):\n",
    "            neighbours = np.empty((X_train.shape[0],2))\n",
    "            # 对每一个train进行计算\n",
    "            for j, train in enumerate(X_train):\n",
    "                dis = euclidean_distance(train,test)\n",
    "                label = y_train[j]\n",
    "                neighbours[j] = [dis,label]\n",
    "            k_nearest_neighbors = neighbours[neighbours[:,0].argsort()][:self.k]\n",
    "            label = self._vote(k_nearest_neighbors)\n",
    "            y_pred[i] = label\n",
    "        return y_pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = datasets.load_iris()\n",
    "X = normalize(data.data)\n",
    "y = data.target\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "clf = KNN(k=5)\n",
    "y_pred = clf.predict(X_test, X_train, y_train)\n",
    "accuracy = accuracy_score(y_test, y_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.97959183673469385"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "accuracy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 参考\n",
    "1. 《统计学习方法》\n",
    "2. 《机器学习》（周志华）\n",
    "3. [blog](https://www.cnblogs.com/ybjourney/p/4702562.html)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
